library(mc2d)
library(Pareto)
library(latex2exp)
library(clue)
######################################################
################## functions #########################
######################################################
# Misclassification rate of `clusters` against `true_label`, up to relabelling.
#
# The confusion matrix is solved as a maximum-weight linear assignment
# (clue::solve_LSAP): labels on one side are matched one-to-one with labels on
# the other side so that the number of agreeing nodes is as large as possible.
# The error is the share of nodes that fall outside that best matching.
#
# Args:
#   true_label: vector of true community labels (any label values).
#   clusters:   vector of estimated labels, same length as `true_label`.
# Returns: the fraction of misclassified nodes, a number in [0, 1].
recovery_error <- function(true_label, clusters) {
  conf_matrix <- table(true_label, clusters)
  # solve_LSAP needs at least as many columns as rows. The best matched count
  # is the same for a matrix and its transpose, so flip when needed.
  if (nrow(conf_matrix) > ncol(conf_matrix)) {
    conf_matrix <- t(conf_matrix)
  }
  alignment <- solve_LSAP(conf_matrix, maximum = TRUE)
  # Read the matched counts straight from the table instead of mapping label
  # values to indices, so labels need not be exactly 1..K.
  matched_cells <- cbind(seq_len(nrow(conf_matrix)), as.integer(alignment))
  matched <- sum(conf_matrix[matched_cells])
  n_nodes <- length(true_label)
  (n_nodes - matched) / n_nodes
}

### Relabel estimated clusters so they agree with the true labels.
# Each estimated cluster label is replaced by the true label it is matched to
# under the maximum-agreement one-to-one assignment (clue::solve_LSAP on the
# clusters x true_label confusion matrix).
#
# Args:
#   true_label: vector of true labels; the matched labels are converted with
#               as.integer(), so they should be integer-like.
#   clusters:   vector of estimated labels (any values). There must be no more
#               distinct clusters than distinct true labels, because
#               solve_LSAP needs at least as many columns as rows.
# Returns: integer vector, same length as `clusters`, in true-label units.
align_label <- function(true_label, clusters) {
  conf_matrix <- table(clusters, true_label)
  alignment <- solve_LSAP(conf_matrix, maximum = TRUE)
  # alignment[r] is the column (true label) matched to row r (cluster label).
  new_labels <- colnames(conf_matrix)[alignment]
  # Find each cluster's row by name rather than using its value as an index,
  # so cluster labels need not be exactly 1..nrow(conf_matrix).
  row_of_cluster <- match(as.character(clusters), rownames(conf_matrix))
  as.integer(new_labels[row_of_cluster])
}

## Refitting theta
# Re-estimate the degree parameters theta from the network and a partition.
#
# Args:
#   clusters:     length-n vector of community labels; each label is used as a
#                 column index of `community_1s` (labels 1..K in this script).
#   network_data: n x n adjacency matrix.
#   community_1s: n x K 0/1 membership matrix; column k marks community k.
#   method:       2 -> theta_i = d_i / (sum of degrees in i's community)
#                        * sqrt(sum of network_data over S x S),
#                      where d_i is the row sum of node i and S is i's community;
#                 3 -> "cancellation trick": with the diagonal set to 0 and
#                      a = network_data[i, S'], S' = i's community minus i,
#                      theta_i^2 = (a' (1 - A_S') a - sum(a^2))
#                                  / ((1 - a)' A_S' (1 - a)).
# Returns: numeric vector of length n. Stops for any other `method`.
refit_theta <- function(clusters, network_data, community_1s, method) {
  n <- nrow(network_data)
  # if (method ==1){
  #   hat_theta<-rep(0, n)
  #   for ( i in 1:n){
  #     c_label<- clusters[i] ##community label 
  #     hat_theta[i]<- sum(network_data[i,]*community_1s[, c_label]) / sqrt( sum( network_data[which(community_1s[, c_label] ==1 ), which(community_1s[, c_label] ==1 )] ) )
  #   }
  # }
  if (method == 2) {
    # Everything except d_i depends only on the community, so compute it once
    # per community instead of re-running rowSums() on the whole matrix for
    # every node (which made this branch O(n^3)).
    degrees <- rowSums(network_data)
    K <- ncol(community_1s)
    comm_degree <- numeric(K)
    within_sum <- numeric(K)
    for (k in seq_len(K)) {
      members <- which(community_1s[, k] == 1)
      comm_degree[k] <- sum(community_1s[, k] * degrees)
      within_sum[k] <- sum(network_data[members, members])
    }
    # Keep character labels working the same way as column-name indexing.
    names(comm_degree) <- colnames(community_1s)
    names(within_sum) <- colnames(community_1s)
    hat_theta <- as.vector(
      (degrees / comm_degree[clusters]) * sqrt(within_sum[clusters])
    )
  } else if (method == 3) { # cancellation trick for theta
    adj <- network_data
    diag(adj) <- 0 # self-loops are excluded from the estimator
    hat_theta <- rep(0, n)
    for (i in seq_len(n)) {
      members <- community_1s[, clusters[i]]
      members[i] <- 0 # leave node i out of its own community
      idx <- which(members == 1)
      a_i <- adj[i, idx]
      A_sub <- adj[idx, idx]
      numer <- a_i %*% (1 - A_sub) %*% a_i - sum(a_i^2)
      denom <- (1 - a_i) %*% A_sub %*% (1 - a_i)
      hat_theta[i] <- sqrt(numer / denom)
    }
  } else {
    stop("refit_theta: unsupported method ", method, " (expected 2 or 3)",
         call. = FALSE)
  }
  hat_theta
}

## Refitting P
# Re-estimate the K x K community connectivity matrix P from a partition.
#
# Args:
#   network_data: n x n adjacency matrix.
#   community_1s: n x K 0/1 membership matrix; column k marks community k.
#   est_theta:    length-n theta estimate (used by method 2 only).
#   method:       1 -> B[k, l] / sqrt(W_k * W_l), where B = Pi' A Pi and
#                      W_k is the sum of network_data over community k x k;
#                 2 -> cancellation trick: (Pi' A Pi) / (Pi' D Pi) with
#                      D[i, j] = (1 - A[i, j]) * theta_i * theta_j.
# Returns: K x K matrix with its diagonal fixed to 1. Stops for any other
#   `method`.
refit_P <- function(network_data, community_1s, est_theta, method) {
  K <- ncol(community_1s)
  if (method == 1) {
    degree_within_com <- vapply(seq_len(K), function(k) {
      members <- which(community_1s[, k] == 1)
      sum(network_data[members, members])
    }, numeric(1))
    scale_k <- sqrt(1 / degree_within_com)
    block_sums <- t(community_1s) %*% network_data %*% community_1s
    # Scale rows, then columns, by 1 / sqrt(W_k).
    hat_P <- t(t(block_sums * scale_k) * scale_k)
  } else if (method == 2) { # cancellation trick for P
    P_num <- t(community_1s) %*% network_data %*% community_1s
    weighted <- t(t((1 - network_data) * est_theta) * est_theta)
    P_dem <- t(community_1s) %*% weighted %*% community_1s
    hat_P <- P_num / P_dem
  } else {
    stop("refit_P: unsupported method ", method, " (expected 1 or 2)",
         call. = FALSE)
  }
  diag(hat_P) <- 1
  hat_P
}

# SCORE embedding of a symmetric matrix.
# Takes the K eigenvectors returned by RSpectra::eigs_sym (default
# which = "LM"), flips the first one to a positive orientation, floors it at
# 1e-15, and returns the entry-wise ratios of eigenvectors 2..K to it.
#
# Args:
#   K:     number of communities; must be >= 2.
#   Omega: symmetric n x n matrix.
# Returns: n x (K-1) matrix of ratios (a plain length-n vector when K = 2,
#   because the single-column subset drops its dimension).
SCORE <- function(K, Omega) {
  if (K < 2) {
    stop("SCORE: K must be at least 2 (ratios use eigenvectors 2..K)",
         call. = FALSE)
  }
  eig_vecs <- RSpectra::eigs_sym(Omega, k = K)$vectors
  lead_vec <- eig_vecs[, 1]
  # Orient the leading eigenvector by its first entry; if that entry is
  # exactly 0 the flip would zero the whole vector, so use the sum instead.
  lead_sign <- sign(lead_vec[1])
  if (lead_sign == 0) {
    lead_sign <- sign(sum(lead_vec))
  }
  Omega_xi1 <- pmax(lead_sign * lead_vec, 10^(-15))
  eig_vecs[, 2:K] * Omega_xi1^(-1)
}

# Print four diagnostics for the (K+1)-th eigenvalue: its value, and its ratio
# to lambda(K), to lambda(1), and the gap |lambda(K+1) - lambda(K)| / lambda(1).
# Each section shows the value for `ordered_Omega_eigVals` first, then the
# values for every row of `ordered_refit_eigenval`.
#
# Args:
#   K:                      number of communities.
#   ordered_Omega_eigVals:  numeric vector of at least K+1 eigenvalues.
#   ordered_refit_eigenval: matrix with at least K+1 columns; presumably one
#                           row per refitting iteration (as adj_refit_eigenval
#                           is built in this script).
# Returns: NULL, invisibly (called for its printed output).
print_eigenval <- function(K, ordered_Omega_eigVals, ordered_refit_eigenval) {
  short_rule <- "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% \n"
  long_rule <- "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% \n"
  # One titled section: the reference value, then the per-row refit values.
  report <- function(rule, title, reference_value, refit_values) {
    cat(rule, title, rule, reference_value, "\n", refit_values, "\n")
  }
  lam <- ordered_Omega_eigVals
  refit <- ordered_refit_eigenval
  report(short_rule, "Evolution of lambda(K+1):\n",
         lam[K + 1], refit[, K + 1])
  report(short_rule, "Evolution of lambda(K+1)/lambda(K):\n",
         lam[K + 1] / lam[K], refit[, K + 1] / refit[, K])
  report(short_rule, "Evolution of lambda(K+1)/lambda(1):\n",
         lam[K + 1] / lam[1], refit[, K + 1] / refit[, 1])
  report(long_rule, "Evolution of |lambda(K+1) - lambda(K)|/lambda(1):\n",
         abs(lam[K + 1] - lam[K]) / lam[1],
         abs(refit[, K + 1] - refit[, K]) / refit[, 1])
  invisible(NULL)
}
######################################################
######################################################
######################################################

############## Generate Omega ####################
# Simulation design: K equal-sized communities of n_0 nodes each.
K <- 5
n_0 <- 500 # number of nodes in each community
n <- n_0 * K # total number of nodes
set.seed(66)

##### generate theta #####
# Degree parameters; uncomment the line of the desired setting.
#theta_0=runif(n,0.01,2) #setting A
#theta_0=runif(n,0.1,0.8) #setting B
#theta_0=pmin(rPareto(n,10,1), 200); #setting C
theta_0 <- pmin(rPareto(n, 10, 1), 100) # setting D
#theta_0=rPareto(n,10,1)


#theta= theta_0
bn <- 50
# Rescale so that the Euclidean norm of theta equals bn.
theta <- bn * theta_0 / sqrt(sum(theta_0^2))
#const=24.5 #SNR
#beta = const/bn
# Shuffle the node order of theta (uses the RNG, after theta_0 is drawn).
theta <- sample(theta)
beta <- 0.45 ## 1 - beta in the paper

#### Pi #####
# n x K membership matrix: the first n_0 nodes are community 1, and so on.
rep_matrix <- matrix(1, nrow = n_0, ncol = 1)
Pi <- kronecker(diag(K), rep_matrix)
true_label <- rep(seq_len(K), each = n_0)

#### Omega ####
# Connectivity: 1 on the diagonal, 1 - beta between communities.
P <- matrix(1 - beta, K, K) + diag(beta, K, K)
# True low-rank part: entry (i, j) is theta_i * theta_j * P[c_i, c_j].
tOmega <- t(t(Pi %*% P %*% t(Pi) * theta) * theta)
Omega <- tOmega / (1 + tOmega) # beta-model Omega


####### Single randomization A ####
# All n^2 Bernoulli draws are made (keeping the RNG stream unchanged), but only
# the strict lower triangle is kept and mirrored, so A is symmetric with a
# zero diagonal.
A0 <- matrix(rbern(n^2, c(Omega)), n, n)
A <- matrix(0, n, n)
A[lower.tri(A)] <- A0[lower.tri(A0)]
A <- A + t(A)

################ Oracle ################
# SCORE clustering on the noiseless beta-model Omega, as a reference result.
tOmega_eigRes <- RSpectra::eigs_sym(tOmega, k = K)
tOmega_eigVals <- tOmega_eigRes$values
tOmega_eigVecs <- tOmega_eigRes$vectors
Omega_eigRes <- RSpectra::eigs_sym(Omega, k = K)
Omega_eigVals <- Omega_eigRes$values
Omega_eigVecs <- Omega_eigRes$vectors


#### SCORE ####
# Orient the leading eigenvector positively (floored at 1e-15), then divide
# eigenvectors 2..K by it entry-wise.
Omega_xi1 <- pmax(sign(Omega_eigVecs[1, 1]) * Omega_eigVecs[, 1], 10^(-15))
Omega_R <- Omega_eigVecs[1:n, 2:K] * Omega_xi1^(-1)
print(c("Low rank", tOmega_eigVals))
print(c("beta_model", Omega_eigVals))

ini_clusters <- kmeans(Omega_R, centers = K, nstart = 25)$cluster
aligned_clusters <- align_label(true_label, ini_clusters)
initial_loss <- recovery_error(true_label, ini_clusters)
print(initial_loss)
# 0/1 membership matrix of the aligned oracle clustering.
ini_hat_1 <- matrix(0, nrow = n, ncol = K)
for (k in seq_len(K)) ini_hat_1[, k] <- as.integer(aligned_clusters == k)

########################### Refitting & SCORE ########################
repe <- 10 # number of refitting iterations
### 4 methods ####
# Row = c(method passed to refit_theta, method passed to refit_P).
Method <- rbind(
  c(2, 1),
  c(3, 1),
  c(2, 2),
  c(3, 2) # our proposed method with cancellation trick
)


############# Empirical Case #########
# Initial SCORE clustering on the observed adjacency matrix A.
A_eigRes <- RSpectra::eigs_sym(A, k = K)
A_eigVals <- A_eigRes$values
A_eigVecs <- A_eigRes$vectors
adj_xi1 <- pmax(sign(A_eigVecs[1, 1]) * A_eigVecs[, 1], 10^(-15))
adj_R <- A_eigVecs[1:n, 2:K] * adj_xi1^(-1)

# Use several random starts, as every other kmeans call in this script does.
# A single start can stop in a poor local optimum, which inflates the baseline
# error the refitting iterations are compared against.
adj_ini_clusters <- kmeans(adj_R, centers = K, nstart = 25)$cluster
adj_aligned_clusters <- align_label(true_label, adj_ini_clusters)
adj_initial_loss <- recovery_error(true_label, adj_ini_clusters)
print(adj_initial_loss)
adj_ini_hat_1 <- matrix(0, nrow = n, ncol = K)
for (k in seq_len(K)) {
  adj_ini_hat_1[, k] <- as.integer(adj_aligned_clusters == k)
}

adj_loss <- rep(0, repe)
hat_1s <- adj_ini_hat_1 ## also Pi
network_data <- A
clusters <- adj_aligned_clusters


adj_refit_eigenval <- matrix(0, nrow = repe, ncol = K + 1)
A_values <- RSpectra::eigs_sym(A, k = K + 1)$values
A_values <- A_values[order(-abs(A_values))]
print(A_values)

method <- Method[4, ] # c(theta method, P method)
for (i in seq_len(repe)) {
  # Refit theta and P from the current partition.
  ref_theta <- refit_theta(clusters = clusters, network_data = network_data,
                           community_1s = hat_1s, method = method[1])
  ref_P <- refit_P(network_data = network_data, community_1s = hat_1s,
                   est_theta = ref_theta, method = method[2])
  print(c("P_error:", sum(abs(ref_P - P)),
          "theta_error:", sum((ref_theta - theta)^2)))
  # Estimated low-rank part theta_i * theta_j * P[c_i, c_j]; A is reweighted
  # entry-wise by (1 + est_tOmega), mirroring Omega = tOmega / (1 + tOmega).
  est_tOmega <- t(t((hat_1s %*% ref_P %*% t(hat_1s)) * ref_theta) * ref_theta)
  est_K_inv <- 1 + est_tOmega
  est_Omega <- est_K_inv * network_data
  est_Omega_R <- SCORE(K, est_Omega)

  est_O_eigval <- RSpectra::eigs_sym(est_Omega, k = K + 1)$values
  # top K+1 eigenvalues after refitting, ordered by magnitude
  adj_refit_eigenval[i, ] <- est_O_eigval[order(-abs(est_O_eigval))]
  temp_clusters <- kmeans(est_Omega_R, centers = K, nstart = 25)$cluster
  adj_loss[i] <- recovery_error(true_label, temp_clusters)
  print(c(i, "th iteration", adj_loss[i]))
  # Align with the true labels (an oracle step) so that community k in hat_1s
  # stays comparable to community k of the true P in the P_error print.
  clusters <- align_label(true_label, temp_clusters)

  # update Pi (hat_1s) from the new partition
  for (k in seq_len(K)) {
    hat_1s[, k] <- as.integer(clusters == k)
  }
}
# Error rate of the initial clustering, then of each refitting iteration.
print(adj_initial_loss)
print(adj_loss)

error_iteration <- append(adj_loss, adj_initial_loss, after = 0)


#############################################################
###################### Plot error rate ######################
#############################################################

# 
###########setting A#############
# plot(0:10,error_iteration,
#      pch=20, cex=1.4, col="red", ylim = range(error_iteration+ c(0.05,rep(0,10))), lty=1, xlab = "Iteration", ylab="Error rate", )
# lines(0:repe, error_iteration, col="red", lty=1,lwd=2)
# abline(h=error_iteration[1], lty=2,lwd=3, col="forestgreen")
# title(main="(A)", line=0.8) 

##########setting B#############
# plot(0:10,error_iteration,
#      pch=20, cex=1.4, col="red", ylim = range(error_iteration+ c(0.02,rep(-0.01,10))), lty=1, xlab = "Iteration", ylab="Error rate", )
# lines(0:repe, error_iteration, col="red", lty=1,lwd=2)
# abline(h=error_iteration[1], lty=2,lwd=3, col="forestgreen")
# title(main="(B)", line=0.8) 

#########setting C#############
# plot(0:10,error_iteration,
#      pch=20, cex=1.4, col="red", ylim = range(error_iteration+ c(0.05,rep(-0.01,10))), lty=1, xlab = "Iteration", ylab="Error rate", )
# lines(0:repe, error_iteration, col="red", lty=1,lwd=2)
# abline(h=error_iteration[1], lty=2,lwd=3, col="forestgreen")
# title(main="(C)", line=0.8) 

#########setting D#############
# plot(0:10,error_iteration,
#      pch=20, cex=1.4, col="red", ylim = range(error_iteration+ c(0.05,rep(-0.01,10))), lty=1, xlab = "Iteration", ylab="Error rate", )
# lines(0:repe, error_iteration, col="red", lty=1,lwd=2)
# abline(h=error_iteration[1], lty=2,lwd=3, col="forestgreen")
# title(main="(D)", line=0.8) 

